{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running as a Jupyter notebook - intended for development only!\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/var/folders/m3/z6c6rcdj1rbb2jh9vqpgvxg40000gn/T/ipykernel_86391/3507779555.py:18: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n",
      "  ipython.magic(\"load_ext autoreload\")\n",
      "/var/folders/m3/z6c6rcdj1rbb2jh9vqpgvxg40000gn/T/ipykernel_86391/3507779555.py:19: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n",
      "  ipython.magic(\"autoreload 2\")\n"
     ]
    }
   ],
   "source": [
    "# NBVAL_IGNORE_OUTPUT\n",
    "# Environment-dependent setup: Colab and GitHub Actions install dependencies, while a local\n",
    "# Jupyter kernel enables autoreload so TransformerLens edits are picked up without a restart.\n",
    "import os\n",
    "\n",
    "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n",
    "\n",
    "try:\n",
    "    import google.colab  # only imported to detect the Colab runtime\n",
    "\n",
    "    IN_COLAB = True\n",
    "    print(\"Running as a Colab notebook\")\n",
    "except ImportError:\n",
    "    IN_COLAB = False\n",
    "    print(\"Running as a Jupyter notebook - intended for development only!\")\n",
    "    from IPython import get_ipython\n",
    "\n",
    "    ipython = get_ipython()\n",
    "    if ipython is not None:\n",
    "        # Automatically reload the TransformerLens code as it is edited, without restarting the kernel.\n",
    "        # run_line_magic replaces the deprecated ipython.magic(...) API.\n",
    "        ipython.run_line_magic(\"load_ext\", \"autoreload\")\n",
    "        ipython.run_line_magic(\"autoreload\", \"2\")\n",
    "\n",
    "\n",
    "if IN_COLAB or IN_GITHUB:\n",
    "    # %pip hands its arguments to a shell, so version specifiers must be quoted: an unquoted `>`\n",
    "    # would redirect pip's output to a file named \"=4.31.0\" instead of pinning the version.\n",
    "    # %pip install sentencepiece # Llama tokenizer requires sentencepiece\n",
    "    # Llama requires transformers>=4.31.0, and transformers in turn requires Python 3.8\n",
    "    %pip install \"transformers>=4.31.0\"\n",
    "    %pip install torch\n",
    "    %pip install tiktoken\n",
    "    # %pip install transformer_lens\n",
    "    %pip install transformers_stream_generator\n",
    "    # !huggingface-cli login --token <YOUR_HF_TOKEN>  (never commit a real token)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TransformerLens currently supports 216 models out of the box.\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from transformer_lens import HookedTransformer, HookedEncoderDecoder, HookedEncoder, BertNextSentencePrediction, loading\n",
    "from transformers import AutoTokenizer, LlamaForCausalLM, LlamaTokenizer\n",
    "from typing import List\n",
    "import gc\n",
    "\n",
    "# Every officially supported model starts out as untested; the cells below remove each set they cover.\n",
    "untested_models = list(loading.OFFICIAL_MODEL_NAMES)\n",
    "\n",
    "print(f\"TransformerLens currently supports {len(untested_models)} models out of the box.\")\n",
    "\n",
    "# When True, the run_* helpers also exercise each loaded model (sample generation / predictions).\n",
    "GENERATE = True\n",
    "# Directory holding the non-hosted Llama weights, one sub-directory per model name; leave empty to skip them.\n",
    "LLAMA_MODEL_PATH = \"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def mark_models_as_tested(model_set: List[str]) -> None:\n",
    "    \"\"\"Remove every model in `model_set` from the global `untested_models` list.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if a model is not in `untested_models` (for example it is no longer in\n",
    "            loading.OFFICIAL_MODEL_NAMES, or it appears in more than one set). Unlike the bare\n",
    "            `list.remove` error, the message names the offending model.\n",
    "    \"\"\"\n",
    "    for model in model_set:\n",
    "        try:\n",
    "            untested_models.remove(model)\n",
    "        except ValueError:\n",
    "            raise ValueError(\n",
    "                f\"{model!r} is not in untested_models: it may have been removed from \"\n",
    "                \"loading.OFFICIAL_MODEL_NAMES or listed in more than one model set.\"\n",
    "            ) from None\n",
    "\n",
    "\n",
    "def run_set(model_set: List[str], device=\"cuda\") -> None:\n",
    "    \"\"\"Load each model via HookedTransformer.from_pretrained_no_processing and optionally sample from it.\n",
    "\n",
    "    Each model is released before the next one is loaded; on Colab the Hugging Face hub cache is\n",
    "    also wiped after every model (presumably to keep disk usage bounded).\n",
    "    \"\"\"\n",
    "    for model_name in model_set:\n",
    "        print(f\"Testing {model_name}\")\n",
    "        tl_model = HookedTransformer.from_pretrained_no_processing(model_name, device=device)\n",
    "        if GENERATE:\n",
    "            completion = tl_model.generate(\"Hello my name is\")\n",
    "            print(completion)\n",
    "        del tl_model\n",
    "        gc.collect()\n",
    "        if IN_COLAB:\n",
    "            %rm -rf /root/.cache/huggingface/hub/models*\n",
    "\n",
    "def run_llama_set(model_set: List[str], weight_root: str, device=\"cuda\") -> None:\n",
    "    \"\"\"Load locally stored Llama weights into HookedTransformer and optionally sample from each model.\n",
    "\n",
    "    Args:\n",
    "        model_set: model names; each must also be the name of a sub-directory of `weight_root`.\n",
    "        weight_root: directory containing one weights sub-directory per model. A trailing path\n",
    "            separator is optional (the path is built with os.path.join, not string concatenation).\n",
    "        device: device passed to HookedTransformer.from_pretrained_no_processing.\n",
    "    \"\"\"\n",
    "    for model in model_set:\n",
    "        print(\"Testing \" + model)\n",
    "        # `os` is imported by the first cell of this notebook.\n",
    "        model_path = os.path.join(weight_root, model)\n",
    "        tokenizer = LlamaTokenizer.from_pretrained(model_path)\n",
    "        hf_model = LlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True)\n",
    "        tl_model = HookedTransformer.from_pretrained_no_processing(\n",
    "            model,\n",
    "            hf_model=hf_model,\n",
    "            device=device,\n",
    "            fold_ln=False,\n",
    "            center_writing_weights=False,\n",
    "            center_unembed=False,\n",
    "            tokenizer=tokenizer,\n",
    "        )\n",
    "        if GENERATE:\n",
    "            print(tl_model.generate(\"Hello my name is\"))\n",
    "        # Release the Hugging Face model and tokenizer as well: otherwise the previous hf_model stays\n",
    "        # referenced (and resident in memory) while the next model is being loaded.\n",
    "        del tl_model, hf_model, tokenizer\n",
    "        gc.collect()\n",
    "        if IN_COLAB:\n",
    "            %rm -rf /root/.cache/huggingface/hub/models*\n",
    "\n",
    "\n",
    "def run_encoder_decoder_set(model_set: List[str], device=\"cuda\", max_new_tokens: int = 64) -> None:\n",
    "    \"\"\"Load each encoder-decoder model and, if GENERATE, greedily decode a reply to a fixed prompt.\n",
    "\n",
    "    Args:\n",
    "        model_set: model names passed to HookedEncoderDecoder.from_pretrained and AutoTokenizer.\n",
    "        device: device passed to HookedEncoderDecoder.from_pretrained.\n",
    "        max_new_tokens: maximum number of decoding steps, so a model that never emits the EOS\n",
    "            token cannot loop forever.\n",
    "    \"\"\"\n",
    "    for model in model_set:\n",
    "        print(\"Testing \" + model)\n",
    "        tokenizer = AutoTokenizer.from_pretrained(model)\n",
    "        tl_model = HookedEncoderDecoder.from_pretrained(model, device=device)\n",
    "        if GENERATE:\n",
    "            # Originally from the t5 demo\n",
    "            prompt = \"Hello, how are you? \"\n",
    "            inputs = tokenizer(prompt, return_tensors=\"pt\")\n",
    "            input_ids = inputs[\"input_ids\"]\n",
    "            attention_mask = inputs[\"attention_mask\"]\n",
    "            decoder_input_ids = torch.tensor([[tl_model.cfg.decoder_start_token_id]]).to(input_ids.device)\n",
    "\n",
    "            # Inference only: torch.no_grad avoids recording an autograd graph for every decoding step.\n",
    "            with torch.no_grad():\n",
    "                for _ in range(max_new_tokens):\n",
    "                    logits = tl_model.forward(input=input_ids, one_zero_attention_mask=attention_mask, decoder_input=decoder_input_ids)\n",
    "                    # logits.shape == (batch_size (1), predicted_pos, vocab_size)\n",
    "\n",
    "                    token_idx = torch.argmax(logits[0, -1, :]).item()\n",
    "                    print(\"generated token: \\\"\", tokenizer.decode(token_idx), \"\\\", token id: \", token_idx, sep=\"\")\n",
    "\n",
    "                    # append token to decoder_input_ids\n",
    "                    decoder_input_ids = torch.cat([decoder_input_ids, torch.tensor([[token_idx]]).to(input_ids.device)], dim=-1)\n",
    "\n",
    "                    # stop once the End-Of-Sequence token is generated\n",
    "                    if token_idx == tokenizer.eos_token_id:\n",
    "                        break\n",
    "        del tl_model\n",
    "        gc.collect()\n",
    "        if IN_COLAB:\n",
    "            %rm -rf /root/.cache/huggingface/hub/models*\n",
    "\n",
    "def run_encoder_only_set(model_set: List[str], device=\"cuda\") -> None:\n",
    "    \"\"\"Load each BERT-style encoder and, if GENERATE, check masked-LM and next-sentence prediction.\"\"\"\n",
    "    for model in model_set:\n",
    "        print(\"Testing \" + model)\n",
    "        tl_model = HookedEncoder.from_pretrained(model, device=device)\n",
    "        # BertNextSentencePrediction is the name imported at the top of the notebook (the previous\n",
    "        # `NextSentencePrediction` was undefined and raised NameError). It wraps the encoder that\n",
    "        # is already loaded instead of loading the weights a second time.\n",
    "        tl_model_nsp = BertNextSentencePrediction(tl_model)\n",
    "\n",
    "        if GENERATE:\n",
    "            print(\"Testing Masked Language Modelling:\")\n",
    "            # Slightly adapted version of the BERT demo\n",
    "            prompt = \"The capital of France is [MASK].\"\n",
    "\n",
    "            prediction = tl_model(prompt, return_type=\"predictions\")\n",
    "\n",
    "            print(f\"Prompt: {prompt}\")\n",
    "            print(f'Prediction: \"{prediction}\"')\n",
    "\n",
    "            print(\"Testing Next Sentence Prediction:\")\n",
    "            sentence_a = \"She went to the grocery store.\"\n",
    "            sentence_b = \"She bought some milk.\"\n",
    "\n",
    "            prediction = tl_model_nsp([sentence_a, sentence_b], return_type=\"predictions\")\n",
    "\n",
    "            print(f\"Sentence A: {sentence_a}\")\n",
    "            print(f\"Sentence B: {sentence_b}\")\n",
    "            print(f\"Prediction: {prediction}\")\n",
    "\n",
    "        # The NSP wrapper was built from tl_model, so both names must be dropped to free the encoder.\n",
    "        del tl_model_nsp, tl_model\n",
    "        gc.collect()\n",
    "        if IN_COLAB:\n",
    "            %rm -rf /root/.cache/huggingface/hub/models*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The following models can run in the T4 free environment (the free Colab GPU tier).\n",
    "# On Colab each one is loaded (and exercised, if GENERATE) via run_set; in every environment\n",
    "# they are then removed from untested_models.\n",
    "free_compatible = [\n",
    "    \"ai-forever/mGPT\",\n",
    "    \"ArthurConmy/redwood_attn_2l\",\n",
    "    \"bigcode/santacoder\",\n",
    "    \"bigscience/bloom-1b1\",\n",
    "    \"bigscience/bloom-560m\",\n",
    "    \"distilgpt2\",\n",
    "    \"EleutherAI/gpt-neo-1.3B\",\n",
    "    \"EleutherAI/gpt-neo-125M\",\n",
    "    \"EleutherAI/gpt-neo-2.7B\",\n",
    "    \"EleutherAI/pythia-1.4b\",\n",
    "    \"EleutherAI/pythia-1.4b-deduped\",\n",
    "    \"EleutherAI/pythia-1.4b-deduped-v0\",\n",
    "    \"EleutherAI/pythia-1.4b-v0\",\n",
    "    \"EleutherAI/pythia-14m\",\n",
    "    \"EleutherAI/pythia-160m\",\n",
    "    \"EleutherAI/pythia-160m-deduped\",\n",
    "    \"EleutherAI/pythia-160m-deduped-v0\",\n",
    "    \"EleutherAI/pythia-160m-seed1\",\n",
    "    \"EleutherAI/pythia-160m-seed2\",\n",
    "    \"EleutherAI/pythia-160m-seed3\",\n",
    "    \"EleutherAI/pythia-160m-v0\",\n",
    "    \"EleutherAI/pythia-1b\",\n",
    "    \"EleutherAI/pythia-1b-deduped\",\n",
    "    \"EleutherAI/pythia-1b-deduped-v0\",\n",
    "    \"EleutherAI/pythia-1b-v0\",\n",
    "    \"EleutherAI/pythia-31m\",\n",
    "    \"EleutherAI/pythia-410m\",\n",
    "    \"EleutherAI/pythia-410m-deduped\",\n",
    "    \"EleutherAI/pythia-410m-deduped-v0\",\n",
    "    \"EleutherAI/pythia-410m-v0\",\n",
    "    \"EleutherAI/pythia-70m\",\n",
    "    \"EleutherAI/pythia-70m-deduped\",\n",
    "    \"EleutherAI/pythia-70m-deduped-v0\",\n",
    "    \"EleutherAI/pythia-70m-v0\",\n",
    "    \"facebook/opt-1.3b\",\n",
    "    \"facebook/opt-125m\",\n",
    "    \"gpt2\",\n",
    "    \"gpt2-large\",\n",
    "    \"gpt2-medium\",\n",
    "    \"gpt2-xl\",\n",
    "    \"meta-llama/Llama-3.2-1B\",\n",
    "    \"meta-llama/Llama-3.2-1B-Instruct\",\n",
    "    \"microsoft/phi-1\",\n",
    "    \"microsoft/phi-1_5\",\n",
    "    \"NeelNanda/Attn-Only-2L512W-Shortformer-6B-big-lr\",\n",
    "    \"NeelNanda/Attn_Only_1L512W_C4_Code\",\n",
    "    \"NeelNanda/Attn_Only_2L512W_C4_Code\",\n",
    "    \"NeelNanda/Attn_Only_3L512W_C4_Code\",\n",
    "    \"NeelNanda/Attn_Only_4L512W_C4_Code\",\n",
    "    \"NeelNanda/GELU_1L512W_C4_Code\",\n",
    "    \"NeelNanda/GELU_2L512W_C4_Code\",\n",
    "    \"NeelNanda/GELU_3L512W_C4_Code\",\n",
    "    \"NeelNanda/GELU_4L512W_C4_Code\",\n",
    "    \"NeelNanda/SoLU_10L1280W_C4_Code\",\n",
    "    \"NeelNanda/SoLU_10L_v22_old\",\n",
    "    \"NeelNanda/SoLU_12L1536W_C4_Code\",\n",
    "    \"NeelNanda/SoLU_12L_v23_old\",\n",
    "    \"NeelNanda/SoLU_1L512W_C4_Code\",\n",
    "    \"NeelNanda/SoLU_1L512W_Wiki_Finetune\",\n",
    "    \"NeelNanda/SoLU_1L_v9_old\",\n",
    "    \"NeelNanda/SoLU_2L512W_C4_Code\",\n",
    "    \"NeelNanda/SoLU_2L_v10_old\",\n",
    "    \"NeelNanda/SoLU_3L512W_C4_Code\",\n",
    "    \"NeelNanda/SoLU_4L512W_C4_Code\",\n",
    "    \"NeelNanda/SoLU_4L512W_Wiki_Finetune\",\n",
    "    \"NeelNanda/SoLU_4L_v11_old\",\n",
    "    \"NeelNanda/SoLU_6L768W_C4_Code\",\n",
    "    \"NeelNanda/SoLU_6L_v13_old\",\n",
    "    \"NeelNanda/SoLU_8L1024W_C4_Code\",\n",
    "    \"NeelNanda/SoLU_8L_v21_old\",\n",
    "    \"Qwen/Qwen-1_8B\",\n",
    "    \"Qwen/Qwen-1_8B-Chat\",\n",
    "    \"Qwen/Qwen1.5-0.5B\",\n",
    "    \"Qwen/Qwen1.5-0.5B-Chat\",\n",
    "    \"Qwen/Qwen1.5-1.8B\",\n",
    "    \"Qwen/Qwen1.5-1.8B-Chat\",\n",
    "    \"Qwen/Qwen2-0.5B\",\n",
    "    \"Qwen/Qwen2-0.5B-Instruct\",\n",
    "    \"Qwen/Qwen2-1.5B\",\n",
    "    \"Qwen/Qwen2-1.5B-Instruct\",\n",
    "    \"Qwen/Qwen2.5-0.5B\",\n",
    "    \"Qwen/Qwen2.5-0.5B-Instruct\",\n",
    "    \"Qwen/Qwen2.5-1.5B\",\n",
    "    \"Qwen/Qwen2.5-1.5B-Instruct\",\n",
    "    \"Qwen/Qwen3-0.6B\",\n",
    "    \"Qwen/Qwen3-1.7B\",\n",
    "    \"roneneldan/TinyStories-1Layer-21M\",\n",
    "    \"roneneldan/TinyStories-1M\",\n",
    "    \"roneneldan/TinyStories-28M\",\n",
    "    \"roneneldan/TinyStories-2Layers-33M\",\n",
    "    \"roneneldan/TinyStories-33M\",\n",
    "    \"roneneldan/TinyStories-3M\",\n",
    "    \"roneneldan/TinyStories-8M\",\n",
    "    \"roneneldan/TinyStories-Instruct-1M\",\n",
    "    \"roneneldan/TinyStories-Instruct-28M\",\n",
    "    \"roneneldan/TinyStories-Instruct-2Layers-33M\",\n",
    "    \"roneneldan/TinyStories-Instruct-33M\",\n",
    "    \"roneneldan/TinyStories-Instruct-3M\",\n",
    "    \"roneneldan/TinyStories-Instruct-8M\",\n",
    "    \"roneneldan/TinyStories-Instuct-1Layer-21M\",\n",
    "    \"stanford-crfm/alias-gpt2-small-x21\",\n",
    "    \"stanford-crfm/arwen-gpt2-medium-x21\",\n",
    "    \"stanford-crfm/battlestar-gpt2-small-x49\",\n",
    "    \"stanford-crfm/beren-gpt2-medium-x49\",\n",
    "    \"stanford-crfm/caprica-gpt2-small-x81\",\n",
    "    \"stanford-crfm/celebrimbor-gpt2-medium-x81\",\n",
    "    \"stanford-crfm/darkmatter-gpt2-small-x343\",\n",
    "    \"stanford-crfm/durin-gpt2-medium-x343\",\n",
    "    \"stanford-crfm/eowyn-gpt2-medium-x777\",\n",
    "    \"stanford-crfm/expanse-gpt2-small-x777\",\n",
    "]\n",
    "\n",
    "if IN_COLAB:\n",
    "    run_set(free_compatible)\n",
    "\n",
    "mark_models_as_tested(free_compatible)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Models intended for the paid Colab GPU environment (presumably too large for the free T4).\n",
    "# On Colab they are run on run_set's default device; everywhere they are marked as tested.\n",
    "paid_gpu_models = [\n",
    "    \"01-ai/Yi-6B\",\n",
    "    \"01-ai/Yi-6B-Chat\",\n",
    "    \"bigscience/bloom-1b7\",\n",
    "    \"bigscience/bloom-3b\",\n",
    "    \"bigscience/bloom-7b1\",\n",
    "    \"codellama/CodeLlama-7b-hf\",\n",
    "    \"codellama/CodeLlama-7b-Instruct-hf\",\n",
    "    \"codellama/CodeLlama-7b-Python-hf\",\n",
    "    \"EleutherAI/pythia-2.8b\",\n",
    "    \"EleutherAI/pythia-2.8b-deduped\",\n",
    "    \"EleutherAI/pythia-2.8b-deduped-v0\",\n",
    "    \"EleutherAI/pythia-2.8b-v0\",\n",
    "    \"EleutherAI/pythia-6.9b\",\n",
    "    \"EleutherAI/pythia-6.9b-deduped\",\n",
    "    \"EleutherAI/pythia-6.9b-deduped-v0\",\n",
    "    \"EleutherAI/pythia-6.9b-v0\",\n",
    "    \"facebook/opt-2.7b\",\n",
    "    \"facebook/opt-6.7b\",\n",
    "    \"google/gemma-2-2b\",\n",
    "    \"google/gemma-2-2b-it\",\n",
    "    \"google/gemma-2b\",\n",
    "    \"google/gemma-2b-it\",\n",
    "    \"google/gemma-7b\",\n",
    "    \"google/gemma-7b-it\",\n",
    "    \"meta-llama/Llama-2-7b-chat-hf\",\n",
    "    \"meta-llama/Llama-2-7b-hf\",\n",
    "    \"meta-llama/Llama-3.1-8B\",\n",
    "    \"meta-llama/Llama-3.1-8B-Instruct\",\n",
    "    \"meta-llama/Llama-3.2-3B\",\n",
    "    \"meta-llama/Llama-3.2-3B-Instruct\",\n",
    "    \"meta-llama/Meta-Llama-3-8B\",\n",
    "    \"meta-llama/Meta-Llama-3-8B-Instruct\",\n",
    "    \"microsoft/phi-2\",\n",
    "    \"microsoft/Phi-3-mini-4k-instruct\",\n",
    "    \"mistralai/Mistral-7B-Instruct-v0.1\",\n",
    "    \"mistralai/Mistral-7B-v0.1\",\n",
    "    \"mistralai/Mistral-Nemo-Base-2407\",\n",
    "    \"mistralai/Mistral-Small-24B-Base-2501\",\n",
    "    \"Qwen/Qwen-7B\",\n",
    "    \"Qwen/Qwen-7B-Chat\",\n",
    "    \"Qwen/Qwen1.5-4B\",\n",
    "    \"Qwen/Qwen1.5-4B-Chat\",\n",
    "    \"Qwen/Qwen1.5-7B\",\n",
    "    \"Qwen/Qwen1.5-7B-Chat\",\n",
    "    \"Qwen/Qwen2-7B\",\n",
    "    \"Qwen/Qwen2-7B-Instruct\",\n",
    "    \"Qwen/Qwen2.5-3B\",\n",
    "    \"Qwen/Qwen2.5-3B-Instruct\",\n",
    "    \"Qwen/Qwen2.5-7B\",\n",
    "    \"Qwen/Qwen2.5-7B-Instruct\",\n",
    "    \"Qwen/Qwen3-4B\",\n",
    "    \"Qwen/Qwen3-8B\",\n",
    "    \"stabilityai/stablelm-base-alpha-3b\",\n",
    "    \"stabilityai/stablelm-base-alpha-7b\",\n",
    "    \"stabilityai/stablelm-tuned-alpha-3b\",\n",
    "    \"stabilityai/stablelm-tuned-alpha-7b\",\n",
    "]\n",
    "\n",
    "if IN_COLAB:\n",
    "    run_set(paid_gpu_models)\n",
    "\n",
    "mark_models_as_tested(paid_gpu_models)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Models run on CPU (device \"cpu\") in the paid Colab environment, presumably because they do\n",
    "# not fit in its GPU memory. They are marked as tested in every environment.\n",
    "paid_cpu_models = [\n",
    "    \"EleutherAI/gpt-j-6B\",\n",
    "    \"EleutherAI/gpt-neox-20b\",\n",
    "    \"EleutherAI/pythia-12b\",\n",
    "    \"EleutherAI/pythia-12b-deduped\",\n",
    "    \"EleutherAI/pythia-12b-deduped-v0\",\n",
    "    \"EleutherAI/pythia-12b-v0\",\n",
    "    \"facebook/opt-13b\",\n",
    "    \"google/gemma-2-9b\",\n",
    "    \"google/gemma-2-9b-it\",\n",
    "    \"meta-llama/Llama-2-13b-chat-hf\",\n",
    "    \"meta-llama/Llama-2-13b-hf\",\n",
    "    \"microsoft/phi-4\",\n",
    "    \"Qwen/Qwen-14B\",\n",
    "    \"Qwen/Qwen-14B-Chat\",\n",
    "    \"Qwen/Qwen1.5-14B\",\n",
    "    \"Qwen/Qwen1.5-14B-Chat\",\n",
    "    \"Qwen/Qwen2.5-14B\",\n",
    "    \"Qwen/Qwen2.5-14B-Instruct\",\n",
    "]\n",
    "\n",
    "if IN_COLAB:\n",
    "    run_set(paid_cpu_models, \"cpu\")\n",
    "\n",
    "mark_models_as_tested(paid_cpu_models)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Models this notebook never loads (presumably too large for any Colab tier); they are only\n",
    "# marked as accounted for so they do not appear in the untested report.\n",
    "incompatible_models = [\n",
    "    \"01-ai/Yi-34B\",\n",
    "    \"01-ai/Yi-34B-Chat\",\n",
    "    \"facebook/opt-30b\",\n",
    "    \"facebook/opt-66b\",\n",
    "    \"google/gemma-2-27b\",\n",
    "    \"google/gemma-2-27b-it\",\n",
    "    \"meta-llama/Llama-2-70b-chat-hf\",\n",
    "    \"meta-llama/Llama-3.1-70B\",\n",
    "    \"meta-llama/Llama-3.1-70B-Instruct\",\n",
    "    \"meta-llama/Llama-3.3-70B-Instruct\",\n",
    "    \"meta-llama/Meta-Llama-3-70B\",\n",
    "    \"meta-llama/Meta-Llama-3-70B-Instruct\",\n",
    "    \"mistralai/Mixtral-8x7B-Instruct-v0.1\",\n",
    "    \"mistralai/Mixtral-8x7B-v0.1\",\n",
    "    \"Qwen/Qwen2.5-32B\",\n",
    "    \"Qwen/Qwen2.5-32B-Instruct\",\n",
    "    \"Qwen/Qwen2.5-72B\",\n",
    "    \"Qwen/Qwen2.5-72B-Instruct\",\n",
    "    \"Qwen/Qwen3-14B\",\n",
    "    \"Qwen/QwQ-32B-Preview\",\n",
    "]\n",
    "\n",
    "mark_models_as_tested(incompatible_models)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The following models take a few extra steps to function: their weights are not hosted, so they\n",
    "# are only loaded when LLAMA_MODEL_PATH points at a directory containing one sub-directory per\n",
    "# model name below (see run_llama_set). Check the official demo for more information on how to\n",
    "# use them. 7b and 13b will work in the paid environment; 30b and 65b will not work in Colab, but\n",
    "# note that run_llama_set is still called for all four whenever LLAMA_MODEL_PATH is set.\n",
    "not_hosted_models = [\n",
    "    \"llama-7b-hf\",\n",
    "    \"llama-13b-hf\",\n",
    "    \"llama-30b-hf\",\n",
    "    \"llama-65b-hf\",\n",
    "]\n",
    "\n",
    "if LLAMA_MODEL_PATH:\n",
    "    run_llama_set(not_hosted_models, LLAMA_MODEL_PATH)\n",
    "\n",
    "mark_models_as_tested(not_hosted_models)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# These all work on the free version of Colab; on Colab they are run via run_encoder_decoder_set.\n",
    "encoder_decoders = [\n",
    "    \"google-t5/t5-base\",\n",
    "    \"google-t5/t5-large\",\n",
    "    \"google-t5/t5-small\",\n",
    "]\n",
    "if IN_COLAB:\n",
    "    run_encoder_decoder_set(encoder_decoders)\n",
    "\n",
    "mark_models_as_tested(encoder_decoders)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# These models work on the free version of Colab; on Colab they are run via run_encoder_only_set.\n",
    "encoder_only_models = [\n",
    "    \"google-bert/bert-base-cased\",\n",
    "    \"google-bert/bert-base-uncased\",\n",
    "    \"google-bert/bert-large-cased\",\n",
    "    \"google-bert/bert-large-uncased\",\n",
    "]\n",
    "\n",
    "if IN_COLAB:\n",
    "    run_encoder_only_set(encoder_only_models)\n",
    "\n",
    "mark_models_as_tested(encoder_only_models)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Models known to be broken. NOTE(review): this list is never passed to mark_models_as_tested,\n",
    "# so these models still show up in the untested report printed by the last cell - confirm that\n",
    "# is intended.\n",
    "broken_models = [\n",
    "    \"Baidicoot/Othello-GPT-Transformer-Lens\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Baidicoot/Othello-GPT-Transformer-Lens\n"
     ]
    }
   ],
   "source": [
    "# Any models printed below are not covered by any of the sets above. Apart from the entries in\n",
    "# broken_models (which are never passed to mark_models_as_tested), this output should stay\n",
    "# empty. If your PR fails due to this notebook, most likely a model was added to or removed from\n",
    "# loading.OFFICIAL_MODEL_NAMES and the lists in this notebook need updating.\n",
    "print(\"\\n\".join(untested_models))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
